"""Smoke test for ``VideoTransFG``: build the model and run one random batch.

Prints the shape of the ``logits`` output and the ``select`` entries of the
``top_sel`` and ``sub_sel`` outputs.
"""
import os
import sys

# Resolve ``../src`` relative to this file rather than the current working
# directory, so the ``inets`` import works no matter where the script is
# launched from.
sys.path.insert(
    0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'src')
)

import torch
from inets import VideoTransFG, transfg_config, video_config


def main():
    """Build a ``VideoTransFG`` and print selected outputs for random input."""
    imgcfg = transfg_config('ViT-B_16')
    vidcfg = video_config()

    # NOTE(review): positional args look like (image size, num classes,
    # patch/frame param) — meaning not visible here; confirm against inets.
    net = VideoTransFG(imgcfg, vidcfg, 224, 200, 16)

    # NOTE(review): presumably (batch=5, frames*RGB=3*3, H, W) — confirm the
    # channel layout VideoTransFG expects.
    frame = torch.randn((5, 3 * 3, 224, 224))
    y = net(frame)

    print(y['logits'].shape)
    print(y['top_sel']['select'])
    print(y['sub_sel']['select'])


if __name__ == '__main__':
    main()
